"""
High-resolution map in the Indian Ocean

B. Yadidya
Feb 25, 2026
"""

import sys
sys.path.append('.')
from common_imports import *

# --- 1. Input data ---
# Directory holding the input NetCDF files, relative to the working directory.
# (Named data_dir rather than `dir` so the builtin dir() is not shadowed.)
data_dir = "data/"

# Jason-2 nadir tracks (drawn in panel A).
ds_jason = xr.open_dataset(data_dir + 'hycom_var_red_jason2_madag.nc')
# SWOT-pass HYCOM SSH fields, indexed by pass_num (panels B and C).
ds_hycom = xr.open_dataset(data_dir + 'hycom_ssh_var_red_swot_madag.nc')
# HYCOM / HRET slope-x variables and their difference, indexed by pass_num (panels D-F).
ds_diff = xr.open_dataset(data_dir + 'slope_x_diff_hycom_hret_swot_madag.nc')

# Figure size: 18.4 cm wide by at most 10.7 cm tall; matplotlib expects inches.
cm_per_inch = 2.54
fig_width_cm = 18.4
fig_height_max_cm = 10.7
fig_width_in = fig_width_cm / cm_per_inch
fig_height_in = fig_height_max_cm / cm_per_inch

# Six map panels on a 2 x 3 grid, every one in PlateCarree (lon/lat) coordinates.
map_projection = dict(projection=ccrs.PlateCarree())
fig, axs = plt.subplots(
    nrows=2,
    ncols=3,
    figsize=(fig_width_in, fig_height_in),
    constrained_layout=True,
    subplot_kw=map_projection,
)
axes = axs.ravel()  # row-major: axes[0..2] is the top row, axes[3..5] the bottom row

# Map window shared by every panel, in degrees (negative latitudes are south).
lat_min, lat_max = -13, -6
lon_min, lon_max = 50, 60

# Colour settings shared by all panels.
cmap = colormaps.prinsenvlag_r
vmin_ssh = -20                            # SSH rows: symmetric colour range
vmax_ssh = 20
levels_slope = np.linspace(-50, 50, 21)   # slope row: 21 edges -> 20 bins of width 5

# --- 2. Global Map Formatting (applied to every panel) ---
n_cols = axs.shape[1]  # panels per row of the subplot grid
for i, ax in enumerate(axes):
    ax.add_feature(cfeature.LAND, color='grey', alpha=0.8)
    ax.add_feature(cfeature.COASTLINE, linewidth=.5, edgecolor='grey')
    ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

    gl = ax.gridlines(draw_labels=True, alpha=0.3)
    gl.top_labels = False
    gl.right_labels = False
    # Separate dicts so the x and y label styles are not one shared object.
    gl.xlabel_style = {'size': 7}
    gl.ylabel_style = {'size': 7}

    # Label only the outer edges of the grid (axes is row-major): latitude
    # labels on the first column, longitude labels on the bottom row.
    gl.left_labels = i % n_cols == 0
    gl.bottom_labels = i >= len(axes) - n_cols

# --- 3. ROW 1: SSH VARIANCE (Panels A, B, C) ---

# Panel A: Nadir altimeter (JASON-2).
# Each track is drawn separately, keeping only the samples that fall inside
# the map window. After the loop, `sc` holds the handle of the last track drawn.
jason_style = dict(s=2, cmap=cmap, vmin=vmin_ssh, vmax=vmax_ssh,
                   transform=ccrs.PlateCarree())
for track_id in ds_jason.nt.values:
    track = ds_jason.sel(nt=track_id)
    inside = ((track.lat >= lat_min) & (track.lat <= lat_max)
              & (track.lon >= lon_min) & (track.lon <= lon_max))
    if not inside.any():
        continue  # track never enters the map window
    sc = axes[0].scatter(track.lon[inside], track.lat[inside],
                         c=track.hycom_vr[inside], **jason_style)

# Locator inset in panel A: a small global map with the study area shaded red.
ax_inset = inset_axes(
    axes[0],
    width="30%",
    height="30%",
    loc='upper right',
    bbox_to_anchor=(0, 0, 1, 1),
    bbox_transform=axes[0].transAxes,
    axes_class=GeoAxes,
    axes_kwargs=dict(projection=ccrs.PlateCarree(central_longitude=180)),
)

# Minimal geography for context.
ax_inset.add_feature(cfeature.COASTLINE, linewidth=0.4)
ax_inset.add_feature(cfeature.LAND, color='lightgray', alpha=0.6)

# Closed ring of the map-window corners (SW, SE, NE, NW, back to SW).
corners = [
    (lon_min, lat_min),
    (lon_max, lat_min),
    (lon_max, lat_max),
    (lon_min, lat_max),
    (lon_min, lat_min),
]
box_lons, box_lats = map(list, zip(*corners))
ax_inset.fill(box_lons, box_lats, color='red', alpha=0.5,
              edgecolor='red', linewidth=0.8, transform=ccrs.PlateCarree())

# Show the whole globe regardless of what was drawn.
ax_inset.set_global()

madag_passes = ds_hycom.pass_num

# Panels B and C: SWOT passes split by pass-number parity.
# Odd passes go to panel B (ascending), even passes to panel C (descending).
# All odd passes are drawn first, then all even ones.
swot_field = ds_hycom.fssh_var_hycom_no_eddies
for parity, target_ax in ((1, axes[1]), (0, axes[2])):
    for p_num in madag_passes:
        if p_num % 2 == parity:
            swot_field.sel(pass_num=p_num).plot(
                x='longitude', y='latitude', cmap=cmap, vmin=vmin_ssh, vmax=vmax_ssh,
                add_colorbar=False, ax=target_ax, transform=ccrs.PlateCarree())



# --- 4. ROW 2: SLOPE VARIANCE (Panels D, E, F) ---
#
# Every pass is plotted with an explicit extend='both'. When `extend` is
# omitted, xarray infers it separately for each call from that pass's own
# data range. It then builds the discrete colormap from a different number of
# colours, so the same value could be coloured differently from one pass to
# the next within a panel. Fixing `extend` gives all passes and all three
# panels one identical colour mapping.
slope_style = dict(x='longitude', y='latitude', cmap=cmap, levels=levels_slope,
                   extend='both', add_colorbar=False, transform=ccrs.PlateCarree())

for p_num in madag_passes:
    # Panel D: HRET slope
    ds_diff.fsss_var_x_hret.sel(pass_num=p_num).plot(ax=axes[3], **slope_style)

    # Panel E: HYCOM slope
    ds_diff.fsss_var_x_hycom.sel(pass_num=p_num).plot(ax=axes[4], **slope_style)

    # Panel F: HYCOM - HRET difference, read from the precomputed variable
    # diff_hycom_hret_var_x (no arithmetic is done here).
    p_diff = ds_diff.diff_hycom_hret_var_x.sel(pass_num=p_num).plot(
        ax=axes[5], **slope_style)

# --- 5. Annotations and Colorbars ---

# Titles and letters for the six panels, in the row-major order of `axes`.
titles = [
    'Nadir Altimeter (JASON-2)', 'SWOT Ascending', 'SWOT Descending',
    'HRET', 'HYCOM', 'HYCOM - HRET',
]
labels = list('ABCDEF')

# Bold panel letter on a white rounded box, centred on its anchor point.
panel_label_props = {
    'fontsize': 9,
    'fontweight': 'bold',
    'ha': 'center',
    'va': 'center',
    'bbox': {'boxstyle': 'round, pad=0.1', 'facecolor': 'white', 'edgecolor': 'none'},
}

for ax, title, letter in zip(axes, titles, labels):
    ax.set_title(title, fontsize=9, pad=4)
    # Letter near the top-left corner, positioned in axes-fraction coordinates.
    ax.text(0.05, 0.92, letter, transform=ax.transAxes, **panel_label_props)

# Colorbar for SSH (top row).
# Panels A-C all map values with `cmap` over vmin_ssh..vmax_ssh, so the
# colorbar is built from an explicit ScalarMappable with that same mapping.
# Relying on the Jason-2 scatter handle would raise NameError whenever no
# track crosses the map window, because it is only assigned inside that check.
ssh_mappable = plt.cm.ScalarMappable(
    norm=plt.Normalize(vmin=vmin_ssh, vmax=vmax_ssh), cmap=cmap)
cbar1 = fig.colorbar(ssh_mappable, ax=axes[:3], location='bottom', shrink=0.35,
                     pad=0.02, aspect=50, extend='both')
# Units label placed just right of the bar, in colorbar-axes coordinates.
cbar1.ax.text(1.06, 0, 'cm$^2$', transform=cbar1.ax.transAxes,
              va='center', ha='left', fontsize=8)
cbar1.ax.tick_params(which='major', labelsize=7, size=0, width=0)
cbar1.ax.minorticks_off()

# Colorbar for slope (bottom row).
# Panels D-F share `cmap` and `levels_slope`, so the last difference-panel
# plot handle represents the whole row.
cbar2 = fig.colorbar(p_diff, ax=axes[3:], location='bottom', shrink=0.35,
                     pad=0.02, aspect=50)
cbar2.ax.text(1.06, 0, '10$^{-3}$(cm/km)$^2$', transform=cbar2.ax.transAxes,
              va='center', ha='left', fontsize=8)
cbar2.ax.tick_params(which='major', labelsize=7, size=0, width=0)
cbar2.ax.minorticks_off()

fig.savefig('Fig8_jason_swot_hret_hycom_madag.png', dpi=500, bbox_inches='tight')